// Unless explicitly stated otherwise all files in this repository are licensed
// under the Apache License Version 2.0.
// This product includes software developed at Datadog (https://www.datadoghq.com/).
// Copyright 2024 Datadog, Inc.

package waf

import (
	"context"
	"maps"
	"slices"
	"strings"
	"sync"
	"sync/atomic"

	"github.com/DataDog/dd-trace-go/v2/instrumentation/appsec/dyngo"
	"github.com/DataDog/dd-trace-go/v2/instrumentation/appsec/trace"
	"github.com/DataDog/dd-trace-go/v2/internal/appsec/config"
	"github.com/DataDog/dd-trace-go/v2/internal/appsec/limiter"
	"github.com/DataDog/dd-trace-go/v2/internal/log"
	"github.com/DataDog/dd-trace-go/v2/internal/stacktrace"
	"github.com/DataDog/go-libddwaf/v4"
)

type (
	// ContextOperation is the dyngo operation that owns the WAF context of a
	// request. It embeds the service entry span operation it was started from
	// and accumulates the WAF events, stack traces and derivatives produced over
	// the course of the request.
	ContextOperation struct {
		dyngo.Operation
		*trace.ServiceEntrySpanOperation

		// context is an atomic pointer to the current WAF context, replaced via
		// SwapContext. The atomic is meant to make calls to context.Run safe when
		// the context is swapped concurrently.
		context atomic.Pointer[libddwaf.Context]
		// limiter comes from the WAF feature and limits the number of events
		// reported as a whole (not per request).
		limiter limiter.Limiter
		// events stores the WAF events received over the course of the request.
		events []any
		// stacks stores the stack traces received from the WAF over the course of the request.
		stacks []*stacktrace.Event
		// derivatives stores the span tags generated by the WAF over the course of the request.
		derivatives map[string]any
		// supportedAddresses is the set of addresses supported by the WAF.
		supportedAddresses config.AddressSet
		// metrics manages metric reporting for the current execution.
		metrics *ContextMetrics
		// requestBlocked tracks whether the request has been blocked by the WAF.
		requestBlocked bool
		// mu protects events, stacks, derivatives and requestBlocked.
		// NOTE(review): limiter, metrics and supportedAddresses are written by
		// their setters without holding mu. Presumably they are set before the
		// operation is used concurrently; confirm.
		mu sync.Mutex
		// logOnce guards one-time logging in AddEvents when a request exceeds its
		// WAF event budget.
		logOnce sync.Once
	}

	// ContextArgs is the (empty) argument passed when a ContextOperation is started.
	ContextArgs struct{}

	// ContextRes is the (empty) result passed when a ContextOperation is finished.
	ContextRes struct{}

	// RunEvent is the type of event that should be emitted to child operations to run the WAF.
	// It carries the addresses to evaluate and the operation that receives the results.
	RunEvent struct {
		libddwaf.RunAddressData
		dyngo.Operation
	}

	// SecurityEvent is a dyngo data event sent when a security event is detected by the WAF.
	SecurityEvent struct{}
)

// IsArgOf and IsResultOf are marker methods tying ContextArgs and ContextRes to
// ContextOperation. These are the values handed to dyngo.StartAndRegisterOperation
// and dyngo.FinishOperation for this operation type.
func (ContextArgs) IsArgOf(*ContextOperation)   {}
func (ContextRes) IsResultOf(*ContextOperation) {}

// StartContextOperation starts a service entry span operation for span and
// wraps it in a new ContextOperation whose dyngo parent is that span
// operation. The new operation is registered in ctx through dyngo. It returns
// the operation together with the context it was registered in.
func StartContextOperation(ctx context.Context, span trace.TagSetter) (*ContextOperation, context.Context) {
	parent, ctx := trace.StartServiceEntrySpanOperation(ctx, span)

	op := new(ContextOperation)
	op.ServiceEntrySpanOperation = parent
	op.Operation = dyngo.NewOperation(parent)

	registeredCtx := dyngo.StartAndRegisterOperation(ctx, op, ContextArgs{})
	return op, registeredCtx
}

// Finish ends the dyngo operation with an empty ContextRes. It then finishes
// the embedded service entry span operation, in that order.
func (op *ContextOperation) Finish() {
	var res ContextRes
	dyngo.FinishOperation(op, res)

	entrySpan := op.ServiceEntrySpanOperation
	entrySpan.Finish()
}

// SwapContext atomically installs ctx as the current WAF context and returns
// the context it replaced. The previous context may be nil.
func (op *ContextOperation) SwapContext(ctx *libddwaf.Context) *libddwaf.Context {
	previous := op.context.Swap(ctx)
	return previous
}

// SetLimiter sets the limiter consulted by AddEvents to cap the number of WAF
// events reported as a whole. The write is not guarded by op.mu. It is
// presumably done before the operation is used concurrently.
//
// The parameter is named l rather than limiter so that it does not shadow the
// imported limiter package inside the body.
func (op *ContextOperation) SetLimiter(l limiter.Limiter) {
	op.limiter = l
}

// SetMetricsInstance attaches the metrics reporter used for this operation.
// The write is not guarded by op.mu.
func (op *ContextOperation) SetMetricsInstance(metrics *ContextMetrics) { op.metrics = metrics }

// GetMetricsInstance returns the metrics reporter attached with
// SetMetricsInstance, or nil if none was set.
func (op *ContextOperation) GetMetricsInstance() *ContextMetrics {
	current := op.metrics
	return current
}

// SetRequestBlocked marks the request as blocked. After this call,
// AbsorbDerivatives drops response-schema derivatives.
func (op *ContextOperation) SetRequestBlocked() {
	op.mu.Lock()
	op.requestBlocked = true
	op.mu.Unlock()
}

// AddEvents adds WAF events to the operation. It returns true if the operation
// has reached its event budget, either through the per-request maximum or
// through the shared limiter. When it returns true, the given events were not
// recorded. Calling it with no events is a no-op that returns false.
//
// The per-request maximum is checked before the shared limiter. A request that
// has already hit its own cap therefore does not keep consuming limiter tokens
// that other requests could use for events that would be dropped anyway.
func (op *ContextOperation) AddEvents(events ...any) bool {
	if len(events) == 0 {
		return false
	}

	op.mu.Lock()
	defer op.mu.Unlock()

	// The check is on the events already stored and a batch is appended as a
	// whole, so the stored count can exceed this by the last accepted batch.
	const maxWAFEventsPerRequest = 10
	if len(op.events) >= maxWAFEventsPerRequest {
		op.logOnce.Do(func() {
			log.Warn("appsec: ignoring new WAF event due to the maximum number of security events per request was reached")
		})
		return true
	}

	// NOTE(review): Allow is now called while holding op.mu. This assumes Allow
	// is non-blocking and never calls back into this operation.
	if !op.limiter.Allow() {
		// The limiter is shared across requests, so once saturated it would deny
		// every event. Logging through logOnce avoids one error line per event.
		op.logOnce.Do(func() {
			log.Error("appsec: too many WAF events, stopping further reporting")
		})
		return true
	}

	op.events = append(op.events, events...)
	return false
}

// AddStackTraces appends the given stack traces to those recorded for the
// request. A call with no stack traces does nothing and does not take the lock.
func (op *ContextOperation) AddStackTraces(stacks ...*stacktrace.Event) {
	if len(stacks) > 0 {
		op.mu.Lock()
		op.stacks = append(op.stacks, stacks...)
		op.mu.Unlock()
	}
}

// AbsorbDerivatives merges derivatives into the operation's span tags. For
// duplicate keys, the later value wins. Once the request has been marked as
// blocked, keys starting with "_dd.appsec.s.res." (response-schema derivatives)
// are skipped. An empty input is ignored.
func (op *ContextOperation) AbsorbDerivatives(derivatives map[string]any) {
	if len(derivatives) == 0 {
		return
	}

	const responseSchemaPrefix = "_dd.appsec.s.res."

	op.mu.Lock()
	defer op.mu.Unlock()

	if op.derivatives == nil {
		op.derivatives = make(map[string]any, len(derivatives))
	}

	// requestBlocked cannot change while op.mu is held, so read it once.
	skipResponseSchema := op.requestBlocked
	for key, value := range derivatives {
		if skipResponseSchema && strings.HasPrefix(key, responseSchemaPrefix) {
			continue
		}
		op.derivatives[key] = value
	}
}

// Derivatives returns a shallow copy of the derivatives absorbed so far, or nil
// if none were. Callers can use it without holding the lock.
func (op *ContextOperation) Derivatives() map[string]any {
	op.mu.Lock()
	snapshot := maps.Clone(op.derivatives)
	op.mu.Unlock()
	return snapshot
}

// Events returns a shallow copy of the WAF events recorded so far. Callers can
// use it without holding the lock.
func (op *ContextOperation) Events() []any {
	op.mu.Lock()
	snapshot := slices.Clone(op.events)
	op.mu.Unlock()
	return snapshot
}

// StackTraces returns a shallow copy of the stack traces recorded so far. Callers
// can use it without holding the lock.
func (op *ContextOperation) StackTraces() []*stacktrace.Event {
	op.mu.Lock()
	snapshot := slices.Clone(op.stacks)
	op.mu.Unlock()
	return snapshot
}

// OnEvent handles a RunEvent by running the WAF on the event's addresses. Results
// are delivered to the operation carried by the event.
func (op *ContextOperation) OnEvent(event RunEvent) {
	receiver, addresses := event.Operation, event.RunAddressData
	op.Run(receiver, addresses)
}

// SetSupportedAddresses records the set of addresses supported by the WAF.
// The write is not guarded by op.mu.
func (op *ContextOperation) SetSupportedAddresses(addrs config.AddressSet) { op.supportedAddresses = addrs }
